#include <climits>

#include <cuda_runtime.h>
#include <cublas_v2.h>
#include "mex.h"


#pragma comment(lib,"cublas.lib")


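/*
MEX gateway quick reference:

nlhs: number of output arguments (left-hand side)
plhs: array of pointers to the output arguments

nrhs: number of input arguments (right-hand side)
prhs: array of pointers to the input arguments

mxGetPr(prhs[0]); // pointer to the real double data of the matrix
mxGetM(prhs[0]);  // number of rows of the matrix
mxGetN(prhs[0]);  // number of columns of the matrix

plhs[0]=mxCreateDoubleMatrix(M,N,mxREAL); // create the M-by-N output matrix
outData=mxGetPr(plhs[0]);
*/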
/*
 * Computes C = A * B on the GPU with cublasDgemm.
 *
 * A is M-by-Q, B is Q-by-N and C is M-by-N; all three are host buffers stored
 * column-major (MATLAB's layout, which is also what cuBLAS expects, so no
 * transposition is needed). M, N and Q must all be >= 1: cuBLAS rejects
 * leading dimensions of 0.
 *
 * Returns NULL on success, otherwise a pointer to a static string describing
 * the first failure. Every device resource acquired here is released before
 * returning, on both the success and the failure path, so the caller is free
 * to raise a MATLAB error afterwards without leaking GPU memory.
 */
static const char *deviceDgemm(const double *A, const double *B, double *C,
                               int M, int N, int Q)
{
	const size_t bytesA = sizeof(double) * static_cast<size_t>(M) * static_cast<size_t>(Q);
	const size_t bytesB = sizeof(double) * static_cast<size_t>(Q) * static_cast<size_t>(N);
	const size_t bytesC = sizeof(double) * static_cast<size_t>(M) * static_cast<size_t>(N);

	double *dev_A = NULL, *dev_B = NULL, *dev_C = NULL;
	cublasHandle_t handle = NULL;
	const char *err = NULL;

	/* Allocate and upload; stop at the first failing runtime call. */
	cudaError_t cuErr = cudaMalloc((void **)&dev_A, bytesA);
	if (cuErr == cudaSuccess) cuErr = cudaMalloc((void **)&dev_B, bytesB);
	if (cuErr == cudaSuccess) cuErr = cudaMalloc((void **)&dev_C, bytesC);
	if (cuErr == cudaSuccess) cuErr = cudaMemcpy(dev_A, A, bytesA, cudaMemcpyHostToDevice);
	if (cuErr == cudaSuccess) cuErr = cudaMemcpy(dev_B, B, bytesB, cudaMemcpyHostToDevice);
	if (cuErr != cudaSuccess) {
		err = cudaGetErrorString(cuErr);
	}

	if (err == NULL && cublasCreate(&handle) != CUBLAS_STATUS_SUCCESS) {
		handle = NULL;
		err = "cublasCreate failed";
	}

	if (err == NULL) {
		const double alpha = 1;
		const double beta = 0;
		/* Column-major, no transposes: lda = M, ldb = Q, ldc = M. */
		if (cublasDgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, M, N, Q,
		                &alpha, dev_A, M, dev_B, Q, &beta, dev_C, M) != CUBLAS_STATUS_SUCCESS) {
			err = "cublasDgemm failed";
		}
	}

	if (err == NULL) {
		/* Blocking copy; its return code is checked so that a failure
		   reported at this point is not silently ignored. */
		cuErr = cudaMemcpy(C, dev_C, bytesC, cudaMemcpyDeviceToHost);
		if (cuErr != cudaSuccess) {
			err = cudaGetErrorString(cuErr);
		}
	}

	/* Pointers that were never allocated are still NULL; cudaFree(NULL) is a no-op. */
	cudaFree(dev_A);
	cudaFree(dev_B);
	cudaFree(dev_C);
	if (handle != NULL) {
		cublasDestroy(handle);
	}
	return err;
}

/*
 * MEX entry point: C = A * B for real, full, double-precision matrices.
 *
 * nlhs/plhs: number of / pointers to the output arguments (at most one).
 * nrhs/prhs: number of / pointers to the input arguments (exactly two:
 *            A of size M-by-Q and B of size Q-by-N).
 *
 * Raises a MATLAB error (mexErrMsgIdAndTxt) when the argument count is wrong,
 * an input is not a real full double array, a dimension does not fit in the
 * int that cuBLAS takes, the inner dimensions disagree, or any CUDA/cuBLAS
 * call fails. If any of M, N, Q is zero the result is the M-by-N zero matrix
 * and the GPU is not touched.
 */
void mexFunction(int nlhs, mxArray *plhs[],int nrhs, mxArray const *prhs[])
{
	/* nlhs == 0 is accepted: MATLAB still takes plhs[0] and stores it in `ans`. */
	if (nrhs != 2 || nlhs > 1) {
		mexErrMsgIdAndTxt("cublasMatMul:nargs",
		                  "Expected exactly two inputs (A, B) and at most one output.");
	}
	mexPrintf("input and output OK!!\n");

	/* mxGetPr is only meaningful for real, full, double arrays. */
	for (int i = 0; i < 2; ++i) {
		if (!mxIsDouble(prhs[i]) || mxIsComplex(prhs[i]) || mxIsSparse(prhs[i])) {
			mexErrMsgIdAndTxt("cublasMatMul:inputType",
			                  "Both inputs must be real, full, double-precision matrices.");
		}
	}

	//C=A*B
	// A M*Q   B Q*N
	const size_t rowsA = mxGetM(prhs[0]);
	const size_t colsA = mxGetN(prhs[0]);
	const size_t rowsB = mxGetM(prhs[1]);
	const size_t colsB = mxGetN(prhs[1]);

	/* cublasDgemm takes int dimensions; refuse anything that would be truncated. */
	const size_t maxDim = static_cast<size_t>(INT_MAX);
	if (rowsA > maxDim || colsA > maxDim || rowsB > maxDim || colsB > maxDim) {
		mexErrMsgIdAndTxt("cublasMatMul:tooLarge",
		                  "Matrix dimensions exceed the range supported by cublasDgemm.");
	}

	const int M = static_cast<int>(rowsA);
	const int Q = static_cast<int>(colsA);
	const int N = static_cast<int>(colsB);
	if (colsA != rowsB) {
		mexErrMsgIdAndTxt("cublasMatMul:innerDim",
		                  "can't multi: A is %d-by-%d but B is %d-by-%d.",
		                  M, Q, static_cast<int>(rowsB), N);
	}

	const double *A = mxGetPr(prhs[0]);
	const double *B = mxGetPr(prhs[1]);

	/* Output matrix is M-by-N; mxCreateDoubleMatrix zero-initialises it. */
	plhs[0] = mxCreateDoubleMatrix(rowsA, colsB, mxREAL);
	double *C = mxGetPr(plhs[0]);

	/* An empty operand gives an empty or all-zero product, which plhs[0]
	   already holds. cuBLAS would reject the zero leading dimensions. */
	if (M == 0 || N == 0 || Q == 0) {
		return;
	}

	/* The helper has already released all GPU resources when it returns, so
	   raising the MATLAB error here cannot leak device memory. */
	const char *err = deviceDgemm(A, B, C, M, N, Q);
	if (err != NULL) {
		mexErrMsgIdAndTxt("cublasMatMul:gpuFailure", "GPU matrix multiply failed: %s", err);
	}
}